[Gemma 4] Add multimodal support (apply_liger_kernel_to_gemma4 for Gemma4ForConditionalGeneration)#1203
Conversation
|
Hi @dvdimitrov13 . Do you have any update on this pr? |
Follow-up to linkedin#1196 — adds the multimodal entry point `apply_liger_kernel_to_gemma4` for `Gemma4ForConditionalGeneration`. The (B, T, V) bf16 logits tensor on Gemma 4 multimodal training is ~17 GB at T=8192 / vocab=262,144 (and ~34 GB once the loss path upcasts to fp32), OOMing 96 GB cards on Gemma4ForConditionalGeneration SFT. Routing loss through `LigerForCausalLMLoss` materializes only the loss scalar. Shape — unified entry point dispatching on class (per @Mecoli1219's preference in linkedin#1186): - `Gemma4ForConditionalGeneration` → installs `multimodal_forward`, class-level RMSNorm + GeGLU swaps, recurses into `model.model.language_model` for instance-level patches. - `Gemma4ForCausalLM` / `Gemma4TextForCausalLM` / `Gemma4TextModel` → routes to `apply_liger_kernel_to_gemma4_text`. - Registry: adds `"gemma4"` alongside existing `"gemma4_text"`. Drive-by: replaces `cls is not None` with `isinstance(cls, type)` in `apply_liger_kernel_to_gemma4_text`'s `causal_lm_types` filter. The dormant bug fires when the multimodal dispatcher recurses into the text path with a `Gemma4TextModel` — MagicMock (from `unittest.mock.patch`) auto-creates a `Gemma4TextForCausalLM` attribute that slipped past the `is not None` filter and landed in an `isinstance(tuple)`, raising TypeError. Closes the multimodal half of linkedin#1186. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
73f941f to
d5b0578
Compare
|
Apologies for the silence @Mecoli1219 - relatively new to open source contribution, and I misread the draft convention as "waiting for maintainer pre-review before final polish", which I now realise isn't how it works on external repos. Rebased onto current main as a single squashed commit (the previous 31 commits had drifted off the merge base from before #1196 landed), conflicts resolved, marked ready for review. |
Mecoli1219
left a comment
There was a problem hiding this comment.
Thanks @dvdimitrov13, and welcome to the open-source community! Nice first contribution!
Left some inline comments. One bigger-picture thought: we should plan to support Gemma4's vision and audio towers (Gemma4VisionModel / Gemma4AudioModel) in a follow-up. Happy to scope that as a separate PR. Just wanted to flag it so we don't lose track.
| # `if config.<m>_config is not None`, so a None-towers model still | ||
| # constructs as Gemma4ForConditionalGeneration and exercises the | ||
| # multimodal forward we're patching. The towers themselves are | ||
| # polymorphic (AutoModel.from_config) and not in this PR's scope. |
There was a problem hiding this comment.
Gemma4 ships dedicated Gemma4VisionModel & Gemma4AudioModel classes (concrete, not polymorphic like gemma3's SigLIP tower). RMSNorm and the GeGLU/SwiGLU-style MLPs in those towers should be a near-direct port from the text patches — worth doing in a follow-up.
There was a problem hiding this comment.
Will open a follow-up PR after this lands. Scope as I see it: RMSNorm + GeGLU/SwiGLU on Gemma4VisionModel and Gemma4AudioModel, the multimodal projector norms (analogous to gemma3's mm_soft_emb_norm), and the audio convergence test scaffolding I deferred from this PR.
Happy to open a tracking issue first if you'd prefer that over a draft PR.
| print("Liger kernel patches have been reverted.") | ||
|
|
||
|
|
||
| def revert_liger_kernel_to_gemma4(model_config: MiniModelConfig): |
There was a problem hiding this comment.
Can we have a convergence test for Gemma4 multimodal model?
There was a problem hiding this comment.
Added in f23039f. Wanted to be transparent about what came up during validation.
What's in the commit
- New
test/resources/fake_configs/Google/Gemma4/gemma-4-e4b-it/tokenizer_config.json. The chat template emits<image_soft_token>(not<start_of_image>like gemma3):Gemma4Processor.__call__expandsimage_tokenplaceholders to<boi><image_token>*n<eoi>, a different pattern from gemma3's boi-token replacement. mini_gemma4entries in bf16 and fp32test_mini_models_multimodal.py. Vision config haspatch_size=16(matchingGemma4ImageProcessor's hardcoded default - the processor doesn't accept a patch_size kwarg).audio_config=None, so audio coverage stays in the vision/audio tower follow-up.apply_liger_kernel_to_gemma4now acceptslayer_norm: bool = Falseas a no-op kwarg. The convergence framework passeslayer_norm=Trueby default for any model not in its exclusion list, and Gemma4 vision uses RMSNorm. Accept-and-no-op is consistent with the deferred vision/audio tower scope.
Validation, asking for guidance
I spun up an RTX 3090 (Ampere) Vast.ai instance to validate locally. After fixing the items above, the test runs end-to-end, but the numerical comparison fails. So does the baseline mini_gemma3 in both bf16 and fp32 on the same env. My read is this is the same env-sensitivity the original PR ran into with mini_gemma4_text on Blackwell bf16.
Env I used: PyTorch 2.5.1 (cuda12.4 image), transformers 5.7.0 from PyPI. I couldn't pull transformers main because it imports CPUOffloadPolicy from torch.distributed.fsdp which requires torch >= 2.6.
A couple of paths forward:
- Land as-is and let your CI environment validate (assuming it's the configuration where
mini_gemma3passes today). - Happy to retry on a different GPU + torch + transformers combination if you have a known-green spec.
I'd appreciate your call.
| text_classes = tuple( | ||
| cls for cls in (Gemma4ForCausalLM, Gemma4TextForCausalLM, Gemma4TextModel) if isinstance(cls, type) | ||
| ) | ||
| if isinstance(model, text_classes): |
There was a problem hiding this comment.
Will this path ever happen with _apply_liger_kernel? If not, we can throw error like how gemma3 did.
There was a problem hiding this comment.
Confirmed - the routing branch is dead via _apply_liger_kernel. Text variants dispatch through the "gemma4_text" registry entry directly, never reaching "gemma4". Matched gemma3's pattern in d0f7516: TypeError on a non-ConditionalGeneration instance.
The isinstance(cls, type) filter in apply_liger_kernel_to_gemma4_text stays - the recursive call from the multimodal path still hits it with a Gemma4TextModel under unittest.mock.patch.
Per @Mecoli1219's review on linkedin#1203: the text-routing branch in apply_liger_kernel_to_gemma4 is dead code via _apply_liger_kernel. The framework dispatches text variants (Gemma4ForCausalLM / Gemma4TextForCausalLM / Gemma4TextModel) through the "gemma4_text" registry entry directly, never reaching the multimodal "gemma4" entry. Match gemma3's pattern: raise TypeError on a non- ConditionalGeneration instance. The drive-by isinstance(cls, type) filter in apply_liger_kernel_to_gemma4_text stays - the recursive call from this function still hits the text path with a Gemma4TextModel instance under unittest.mock.patch. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Per @Mecoli1219's review on linkedin#1203 (comment on test/utils.py line 508 re: revert_liger_kernel_to_gemma4): adds the corresponding mini convergence test for the multimodal entry point, mirroring mini_gemma3's pattern. Scope: image+text path through Gemma4ForConditionalGeneration. - New test/resources/fake_configs/Google/Gemma4/gemma-4-e4b-it/ tokenizer_config.json. Chat template emits <image_soft_token>; Gemma4Processor.__call__ expands image_token placeholders to <boi><image_token>*n<eoi>, which differs from gemma3's boi-token pattern. - bf16 and fp32 test_mini_models_multimodal.py: GEMMA4_AVAILABLE block, MINI_MODEL_SETUPS["mini_gemma4"] with vision config (patch_size=16 matches Gemma4ImageProcessor's hardcoded default; audio_config=None), create_processor branch, pytest.param with gemma3-matching tolerances. - apply_liger_kernel_to_gemma4 accepts layer_norm: bool = False as a no-op kwarg. The convergence framework defaults to passing layer_norm=True; Gemma4 vision uses RMSNorm so accept-and-no-op is consistent with the deferred vision/audio tower scope. Audio coverage (Gemma4AudioModel + procedural audio generation in the convergence harness) is intentionally deferred to the vision/audio tower follow-up PR; the multimodal forward we patch exercises the LM head FLCE which is the OOM unblock this PR ships. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Mecoli1219
left a comment
There was a problem hiding this comment.
Hi @dvdimitrov13, I tested locally on H100. There are some numerical drifts that cause failures in the convergence test, but they're within the range we see for other multimodal models in bf16. I think it's fine to merge as-is.
Thanks again for the contribution! If you'd like to take on the vision/audio tower follow-up, feel free to open a PR and ping me when it's ready. Looking forward to your future contributions!
Summary
Follow-up to #1196 (the text path) — adds
apply_liger_kernel_to_gemma4forGemma4ForConditionalGeneration(multimodal class; includes E2B / E4B / E4B-it which are loaded byAutoModelForCausalLMas the multimodal class even when only text is being trained).Closes the multimodal half of #1186.
Why
Gemma 4's text vocab is 262,144. Without FLCE the (B, T, V) bf16 logits tensor is ~17 GB at T=8192 (and ~34 GB once the loss path upcasts to fp32 for cross-entropy), which OOMs even 96 GB cards on
Gemma4ForConditionalGenerationSFT — the OOM that originally motivated #1186. Routing loss throughLigerForCausalLMLossmaterializes only the loss scalar.Shape
A single unified entry point that dispatches on class, per @Mecoli1219's preference in #1186:
apply_liger_kernel_to_gemma4(model=Gemma4ForConditionalGeneration_instance)— multimodal path. Class-level RMSNorm + GeGLU swaps viaapply_liger_kernel_to_gemma4_text, FLCE forward via the newmultimodal_forward, recurses intomodel.model.language_modelfor instance-level patches.apply_liger_kernel_to_gemma4(model=Gemma4ForCausalLM_instance)— routes toapply_liger_kernel_to_gemma4_textfor backwards compatibility, so the same entry point works for either shape."gemma4": apply_liger_kernel_to_gemma4alongside the existing"gemma4_text".Drive-by fixes
Two
isinstance(model, tuple_with_None_filter)sites (one in our new dispatcher, one in the existingapply_liger_kernel_to_gemma4_text) raisedTypeError: isinstance() arg 2 must be a type, a tuple of types, or a unionwhen called underwith patch("transformers.models.gemma4.modeling_gemma4"):getattr(MagicMock_module, "Gemma4TextForCausalLM", None)returns a MagicMock (notNone) becauseMagicMockauto-creates attributes, so thecls is not Nonefilter let it slip into theisinstancetuple. The text-path version was dormant — its existing test passes aGemma4ForCausalLMwhich short-circuits theisinstancematch before reaching the bad entry — but the multimodal recursive call into the text path passes aGemma4TextModel, so no early match. Both sites now useisinstance(cls, type)as the filter.Out of scope (deferred)
AutoModel.from_config(config.{vision,audio}_config), so the module classes are polymorphic. Out of scope here; FLCE on the LM head is what unblocks training OOM.Gemma4MultimodalEmbedder/ projector norms. Analogous to gemma3'smm_soft_emb_normpatching. Skipped for the same minimal-surface reason.google/gemma-4-E4B-it.Gemma4TextExperts). Guarded out via the sameenable_moe_blockcheck used by the text path in [Gemma 4] Add apply_liger_kernel_to_gemma4_text (dense text, 31B-targeted) #1196.mini_gemma4). Would need newtest/resources/fake_configs/Google/Gemma4/...scaffolding for an image / audio processor — PR [Gemma 4] Add apply_liger_kernel_to_gemma4_text (dense text, 31B-targeted) #1196 followed the same pattern (only addedmini_gemma4_textto the non-multimodal convergence files). Happy to add it as a follow-up if you'd prefer it bundled here — let me know.Testing Done
Hardware
huggingface/transformers@main— gemma4 requires ≥ 5.5.0)End-to-end numerical equivalence on real
google/gemma-4-E4B-itVerified before authoring this PR with our internal
verify_patch_equivalence.py(same shape as #1196's verification):< 5e-30.0016> 99 %< 5e-30.0016~3.5e-2> 0.9999Liger-Kernel test gates
make checkstyle—All checks passed!, 267 files already formattedmake test— see logmake test-convergence— see logs (per file)make test3131 passed, 903 skipped, 12 xfailed, 3 failedin 35:42.Our two new unit tests pass cleanly:
All six
LigerGEGLUMLPForGemma4edge-case tests added by #1196 also pass.The 3 unrelated failures are pre-existing on the parent branch (untouched by this PR) and reproduce on
main:make test-convergence(per file)fp32/test_mini_models.pyfp32/test_mini_models_multimodal.pymini_qwen2_vl)fp32/test_mini_models_with_logits.pybf16/test_mini_models.pymini_llama4,mini_gemma4_text)bf16/test_mini_models_multimodal.pymini_qwen2_vl,mini_llama4)bf16/test_mini_models_with_logits.pymini_llama4,mini_qwen3_moe)mini_gemma4_textpasses inbf16/test_mini_models_with_logits.pyandfp32/test_mini_models.py(PR #1196's text path). It fails only inbf16/test_mini_models.pywith the same Blackwell bf16 logprob drift @eqy and reviewers flagged on #1196 (review commentr4321013177); this is in PR #1196's territory, not introduced by this multimodal patch. The other failures (mini_llama4,mini_qwen2_vl,mini_qwen3_moe) are also in models we don't touch and are pre-existing test instability on consumer Blackwell.cc @Mecoli1219 @lardinator @ruilin-gif
🤖 Drafted with Claude Code (Claude Opus 4.7), reviewed and posted by me.